import pickle
import numpy as np
import pandas as pd
import umap
from sklearn.cluster import DBSCAN, KMeans, AgglomerativeClustering
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score
import plotly.express as px
import hdbscan
import os

# --- Configuration ---
# Where the input pickle is read from and where results are written.
INPUT_FILE_PATH = "embeddings_data.pkl"
OUTPUT_HTML_PATH = "cluster_visualization_modified.html"
OUTPUT_DIR = "cluster_outputs"  # one .txt file per cluster is written here

# --- Algorithm Selection and Parameters ---
# Supported values: 'hdbscan', 'kmeans', 'agglomerative', 'dbscan'.
ALGORITHM = 'kmeans'

# Number of UMAP components the clustering algorithms operate on.
# NOTE(review): 500 is unusually high for UMAP — confirm it is intended and
# that the input embeddings have at least this many dimensions.
UMAP_COMPONENTS_FOR_CLUSTERING = 500

# Constructor keyword arguments for HDBSCAN, AgglomerativeClustering and DBSCAN.
HDBScan_params = dict(min_cluster_size=15, min_samples=5)
# K-Means gets no fixed K: each K in this range (5 through 20 inclusive) is
# tried and the one with the highest silhouette score is kept.
KMeans_params = dict(cluster_range=range(5, 20 + 1))
Agglomerative_params = dict(n_clusters=10)
DBSCAN_params = dict(eps=0.3, min_samples=3)


def load_data(file_path: str):
    """Read the pickled mapping at *file_path* and return ``(texts, vectors)``.

    Each of the ``"texts"`` and ``"vectors"`` keys falls back to an empty list
    when absent from the mapping. If the file does not exist, an error is
    printed and ``([], [])`` is returned.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading;
    only point this at files from a trusted source.
    """
    print(f"Step 1: Loading data from file {file_path}...")
    missing = ([], [])
    try:
        with open(file_path, 'rb') as stream:
            payload = pickle.load(stream)
            print("Data loaded successfully!")
            return tuple(payload.get(key, []) for key in ("texts", "vectors"))
    except FileNotFoundError:
        print(f"Error: File {file_path} not found")
        return missing


def _reduce_for_clustering(vectors) -> np.ndarray:
    """Standardize *vectors* and project them with UMAP for the clustering step."""
    scaled_vectors = StandardScaler().fit_transform(np.asarray(vectors))
    # Projecting into more dimensions than the input has is not a reduction
    # (it only costs time), so never request more components than features.
    n_components = min(UMAP_COMPONENTS_FOR_CLUSTERING, scaled_vectors.shape[1])
    reducer = umap.UMAP(n_components=n_components, random_state=42)
    print("Reducing dimensionality for clustering...")
    reduced_vectors = reducer.fit_transform(scaled_vectors)
    print("Dimensionality reduction complete.")
    return reduced_vectors


def _best_kmeans_labels(reduced_vectors: np.ndarray) -> np.ndarray:
    """Return K-Means labels for the configured K with the best silhouette score."""
    cluster_range = KMeans_params['cluster_range']
    requested = list(cluster_range)  # allow any iterable, iterated only once
    n_samples = len(reduced_vectors)
    print(f"K-Means: Searching for best K value using Silhouette score within {cluster_range}...")

    # silhouette_score is only defined for 2 <= n_labels <= n_samples - 1, and
    # KMeans rejects n_clusters > n_samples; skip K values outside that window
    # instead of crashing partway through the search.
    candidate_ks = [k for k in requested if 2 <= k <= n_samples - 1]
    skipped = [k for k in requested if not 2 <= k <= n_samples - 1]
    if skipped:
        print(f"  - Skipping K values invalid for {n_samples} samples: {skipped}")
    if not candidate_ks:
        raise ValueError(
            f"No K in {cluster_range} is valid for {n_samples} samples; "
            "K-Means with silhouette scoring needs 2 <= K <= n_samples - 1."
        )

    best_k, best_score, best_labels = None, -np.inf, None
    for k in candidate_ks:
        labels = KMeans(n_clusters=k, random_state=42, n_init='auto').fit_predict(reduced_vectors)
        score = silhouette_score(reduced_vectors, labels)
        print(f"  - K={k}, Silhouette Score: {score:.4f}")
        # Strict '>' keeps the first K on ties, like np.argmax would.
        if score > best_score:
            best_k, best_score, best_labels = k, score, labels

    print(f"\nBest K found: {best_k} (Silhouette Score: {best_score:.4f})")
    # With a fixed random_state KMeans is deterministic, so refitting with
    # best_k would reproduce these exact labels; reuse them instead.
    return best_labels


def perform_clustering(vectors: np.ndarray, algorithm: str):
    """Standardize, UMAP-reduce and cluster *vectors* with the named algorithm.

    Args:
        vectors: One embedding per row; converted with ``np.asarray``.
        algorithm: One of 'hdbscan', 'kmeans', 'agglomerative', 'dbscan'.
            'kmeans' tries every K in ``KMeans_params['cluster_range']`` and
            keeps the one with the highest silhouette score; the others use
            their ``*_params`` dict as constructor arguments.

    Returns:
        The cluster label of each row, as produced by ``fit_predict``.

    Raises:
        ValueError: If *algorithm* is unknown (checked before any expensive
            work), or if no configured K is valid for the number of samples.
    """
    print(f"\nStep 2: Performing clustering using algorithm '{algorithm}'...")

    # Factories rather than instances, so nothing is built for the unused ones.
    factories = {
        'hdbscan': lambda: hdbscan.HDBSCAN(**HDBScan_params),
        'agglomerative': lambda: AgglomerativeClustering(**Agglomerative_params),
        'dbscan': lambda: DBSCAN(**DBSCAN_params),
    }
    # Fail fast on a typo instead of after the slow scaling + UMAP step.
    if algorithm != 'kmeans' and algorithm not in factories:
        raise ValueError("Unknown algorithm name. Choose from 'hdbscan', 'kmeans', 'agglomerative', 'dbscan'.")

    reduced_vectors = _reduce_for_clustering(vectors)

    if algorithm == 'kmeans':
        labels = _best_kmeans_labels(reduced_vectors)
    else:
        labels = factories[algorithm]().fit_predict(reduced_vectors)

    print("Clustering complete!")
    return labels


def create_and_save_plot(texts, vectors, labels, algorithm=None, output_path=None, show=True):
    """Project *vectors* to 2-D with UMAP and save an interactive scatter plot.

    Args:
        texts: One text per point, shown in the hover tooltip.
        vectors: One vector per point; given to a fresh 2-D UMAP fit on the
            raw (non-standardized) values, used only for display.
        labels: Integer cluster label per point. Label -1 is not counted as
            a cluster in the title.
        algorithm: Algorithm name shown in the title; defaults to the
            module-level ``ALGORITHM``.
        output_path: HTML file to write; defaults to ``OUTPUT_HTML_PATH``.
        show: If true (the default), also open the figure with ``fig.show()``;
            pass False for headless runs.
    """
    # Resolve defaults at call time so later changes to the globals are seen.
    if algorithm is None:
        algorithm = ALGORITHM
    if output_path is None:
        output_path = OUTPUT_HTML_PATH
    print("\nStep 3: Generating interactive visualization plot...")

    reducer_viz = umap.UMAP(n_components=2, random_state=42)
    vectors_2d = reducer_viz.fit_transform(np.asarray(vectors))

    df = pd.DataFrame({
        'x': vectors_2d[:, 0],
        'y': vectors_2d[:, 1],
        'text': texts,
        # Strings make Plotly treat clusters as discrete categories.
        'cluster': [str(label) for label in labels],
    })

    # -1 marks noise points (as the file writer also treats it), not a cluster.
    num_clusters = len(set(labels) - {-1})
    title_text = f"Text Clustering Visualization (Algorithm: {algorithm}, {num_clusters} clusters found)"

    fig = px.scatter(
        df, x='x', y='y', color='cluster', hover_data=['text'],
        title=title_text,
        labels={'x': 'UMAP 1', 'y': 'UMAP 2'},
        # Numeric legend order ("2" before "10"), not lexical string order.
        category_orders={"cluster": sorted(df['cluster'].unique(), key=int)},
    )
    fig.update_traces(marker=dict(size=5, opacity=0.8))
    base_font = dict(family='Times New Roman', color='black')
    fig.update_layout(
        legend_title_text='Cluster ID',
        title_font=dict(size=28, **base_font),
        xaxis_title_font=dict(size=20, **base_font),
        yaxis_title_font=dict(size=20, **base_font),
        legend_font=dict(size=18, **base_font),
    )

    fig.write_html(output_path)
    print(f"Plot successfully saved to '{output_path}'")
    if show:
        fig.show()


def save_clusters_to_files(texts: list, labels: np.ndarray, output_dir: str):
    """Write each cluster's texts to its own UTF-8 file, one text per line.

    Files are named ``cluster_<label>.txt``; label ``-1`` (noise) goes to
    ``cluster_noise.txt``. Files with the same name are overwritten; files
    left over from earlier runs for labels that no longer occur are kept.

    Args:
        texts: Texts aligned one-to-one with *labels*; non-str items are
            written via ``str()``.
        labels: Cluster label for each text.
        output_dir: Destination directory; created (with parents) if missing.
    """
    print("\nStep 4: Saving cluster results to files...")

    # exist_ok removes the check-then-create race of exists() + makedirs().
    already_present = os.path.isdir(output_dir)
    os.makedirs(output_dir, exist_ok=True)
    if not already_present:
        print(f"Directory created: {output_dir}")

    df = pd.DataFrame({'text': texts, 'label': labels})

    # Single grouping pass instead of re-filtering the frame once per label;
    # sort=False keeps labels in order of first appearance.
    for label, cluster_texts in df.groupby('label', sort=False)['text']:
        file_name = "cluster_noise.txt" if label == -1 else f"cluster_{label}.txt"
        file_path = os.path.join(output_dir, file_name)

        # Collapse embedded line breaks so each text occupies exactly one line
        # and the file's line count matches the reported item count.
        lines = [" ".join(str(text).splitlines()) + "\n" for text in cluster_texts]
        with open(file_path, 'w', encoding='utf-8') as f:
            f.writelines(lines)

        print(f"  - Saved {len(lines)} items to {file_path}")

    print("All clusters saved successfully!")


if __name__ == "__main__":
    texts_data, vectors_data = load_data(INPUT_FILE_PATH)
    # Use len() rather than truthiness: the pickle may store the vectors as a
    # NumPy array, and bool() of a multi-element array raises ValueError.
    n_texts = 0 if texts_data is None else len(texts_data)
    n_vectors = 0 if vectors_data is None else len(vectors_data)
    if n_texts == 0 or n_vectors == 0:
        print("No texts/vectors to cluster; nothing to do.")
    elif n_texts != n_vectors:
        print(f"Error: {n_texts} texts but {n_vectors} vectors; they must match one-to-one.")
    else:
        final_labels = perform_clustering(vectors_data, ALGORITHM)
        create_and_save_plot(texts_data, vectors_data, final_labels)
        # Save clustering results to files
        save_clusters_to_files(texts_data, final_labels, OUTPUT_DIR)
